import os
import glob
import numpy as np
import pandas as pd
import tkinter as tk
from tkinter import filedialog, messagebox
from tkinter import ttk

import mne
from mne.preprocessing import ICA

# Matplotlib: used to display MNE figures non-blocking via plt.show(block=False)
import matplotlib.pyplot as plt


# -----------------------------
# DEFAULTS
# -----------------------------
# Recording format: sampling rate (Hz) and the channel name of each of the
# 9 data columns in a raw .txt file, in column order.
DEFAULT_SFREQ = 256.0
DEFAULT_CH_NAMES = [
    "POz", "Fz", "Cz",
    "C3", "C4",
    "F3", "F4",
    "P3", "P4",
]

# Filter defaults shown in the GUI (Hz). The notch range is inclusive and is
# expanded to every integer frequency between the two bounds.
DEFAULT_HPASS, DEFAULT_LPASS = 0.2, 40.0
DEFAULT_NOTCH_LO, DEFAULT_NOTCH_HI = 55, 65

# IIR designs passed to MNE as ``iir_params`` (4th-order Butterworth); IIR is
# used because it is safe for short recordings.
IIR_BANDPASS_PARAMS = {"order": 4, "ftype": "butter"}
IIR_NOTCH_PARAMS = {"order": 4, "ftype": "butter"}

# Epoch window in seconds relative to each event; baseline correction is off
# by default (when on, the baseline interval is tmin..0).
DEFAULT_TMIN, DEFAULT_TMAX = -0.2, 1.0
DEFAULT_BASELINE_ON = False

# ICA settings. A float n_components is interpreted by MNE as a fraction of
# explained variance; the fixed random_state makes the fit reproducible.
DEFAULT_ICA_N_COMPONENTS = 0.99
DEFAULT_ICA_METHOD = "fastica"
DEFAULT_ICA_RANDOM_STATE = 97

# Input units / time base: raw txt values are microvolts and are converted to
# volts before building the MNE object; event latencies are milliseconds.
UV_TO_V = 1e-6
EVENT_LATENCY_IS_MS = True


# -----------------------------
# IO / PAIRING
# -----------------------------
def find_pairs(folder: str):
    """Pair each raw-data file in *folder* with its event-words file.

    Files are matched on the participant name, i.e. the text before
    ``" raw"`` / ``" event words"`` in the file name, compared
    case-insensitively. For example ``"Alec raw.txt"`` pairs with
    ``"Alec event words.txt"``. Raw files without a matching event file are
    silently skipped.

    Args:
        folder: Directory to search (not recursive).

    Returns:
        A list of ``(name, raw_path, event_path)`` tuples ordered by raw file
        path; ``name`` keeps the case used in the raw file name.
    """
    # Escape the folder so glob metacharacters in the path itself
    # ("[", "]", "*", "?") are matched literally instead of as a pattern;
    # only the file-name part below is meant to be a wildcard.
    root = glob.escape(folder)
    raw_files = sorted(glob.glob(os.path.join(root, "* raw*.txt")))
    event_files = sorted(glob.glob(os.path.join(root, "* event words*.txt")))

    event_map = {}
    for ef in event_files:
        name = os.path.basename(ef).split(" event words")[0].strip()
        # If two event files fold to the same name, the later one wins.
        event_map[name.lower()] = ef

    pairs = []
    for rf in raw_files:
        name = os.path.basename(rf).split(" raw")[0].strip()
        ef = event_map.get(name.lower())
        if ef is not None:
            pairs.append((name, rf, ef))
    return pairs


def read_9col_raw_txt(raw_path: str) -> np.ndarray:
    """Load a 9-column raw EEG text file and return the samples in volts.

    The delimiter is sniffed by pandas (``sep=None``) and the file is read
    with ``header=None``, so the first line is treated as data. Values are
    scaled from microvolts by ``UV_TO_V``.

    Returns:
        Float array of shape ``(9, n_samples)`` (channels x samples).

    Raises:
        ValueError: If the file does not have exactly 9 columns.
    """
    table = pd.read_csv(raw_path, sep=None, engine="python", header=None)
    n_cols = table.shape[1]
    if n_cols != 9:
        raise ValueError(f"{os.path.basename(raw_path)} has {n_cols} columns; expected 9.")
    # File rows are samples; transpose to channels-first for MNE.
    samples_by_channel = table.to_numpy(dtype=float)
    return samples_by_channel.T * UV_TO_V


def read_event_words_txt(event_path: str) -> pd.DataFrame:
    """Read a whitespace-delimited event table that has a header row.

    Column names are stripped and lower-cased, and the ``type`` column is
    converted to ``str`` so labels compare consistently.

    Raises:
        ValueError: If the ``latency`` or ``type`` column is missing.
    """
    events_df = pd.read_csv(event_path, sep=r"\s+", engine="python")
    events_df.columns = [label.strip().lower() for label in events_df.columns]
    required = ("latency", "type")
    if not all(col in events_df.columns for col in required):
        raise ValueError(f"{os.path.basename(event_path)} must have columns: latency type")
    events_df["type"] = events_df["type"].astype(str)
    return events_df


def build_events(ev_df: pd.DataFrame, sfreq: float, n_samples: int, latency_is_ms=None):
    """Convert an event table into an MNE-style events array.

    Each distinct ``type`` label gets an integer code (1-based, in sorted
    label order). Latencies are converted to sample indices; events outside
    ``[0, n_samples)`` or with a non-finite latency are dropped.

    Args:
        ev_df: Table with at least ``latency`` and ``type`` columns.
        sfreq: Sampling frequency in Hz.
        n_samples: Number of samples in the recording.
        latency_is_ms: True if latencies are milliseconds, False if they are
            already sample indices. ``None`` (default) uses the module-level
            ``EVENT_LATENCY_IS_MS`` setting.

    Returns:
        ``(events, event_id, ev_df_kept, n_dropped)``:
        ``events`` is an ``(n_kept, 3)`` int array of ``[sample, 0, code]``;
        ``event_id`` maps every label that still has at least one kept event
        to its code; ``ev_df_kept`` is a copy of the kept rows; ``n_dropped``
        is the number of rows discarded.
    """
    if latency_is_ms is None:
        latency_is_ms = EVENT_LATENCY_IS_MS

    # Codes are assigned over *all* labels, so a label's code does not depend
    # on which of its events survive the range check.
    labels = sorted(ev_df["type"].unique())
    all_codes = {lab: i + 1 for i, lab in enumerate(labels)}

    lat = ev_df["latency"].to_numpy(dtype=float)
    finite = np.isfinite(lat)
    # Casting NaN/inf to int is platform dependent (it can yield 0, i.e. a
    # valid sample), so only finite latencies are converted; the rest stay -1
    # and are rejected below.
    sample = np.full(lat.shape, -1, dtype=int)
    scaled = (lat[finite] / 1000.0) * sfreq if latency_is_ms else lat[finite]
    sample[finite] = np.rint(scaled).astype(int)

    good = finite & (sample >= 0) & (sample < n_samples)
    dropped = int(np.count_nonzero(~good))
    ev_df2 = ev_df.loc[good].copy()
    sample2 = sample[good]
    codes = ev_df2["type"].map(all_codes).to_numpy(dtype=int)

    # Report only labels that occur in `events`: mne.Epochs raises by default
    # when an event_id entry has no matching events.
    present = set(ev_df2["type"])
    event_id = {lab: code for lab, code in all_codes.items() if lab in present}

    events = np.zeros((len(sample2), 3), dtype=int)
    events[:, 0] = sample2
    events[:, 2] = codes
    return events, event_id, ev_df2, dropped


# -----------------------------
# MNE CONSTRUCTION / FILTERS
# -----------------------------
def make_raw_mne(data_ch_by_samp_v: np.ndarray, sfreq: float, ch_names):
    """Wrap a channels-by-samples array (volts) in an MNE ``RawArray``.

    Every channel is typed as EEG. A standard 10-20 montage is attached on a
    best-effort basis: channels missing from the montage are ignored, and any
    error while building or setting it is swallowed so loading never fails
    because of electrode positions.
    """
    n_ch = len(ch_names)
    info = mne.create_info(ch_names=ch_names, sfreq=sfreq, ch_types=["eeg"] * n_ch)
    raw = mne.io.RawArray(data_ch_by_samp_v, info, verbose="ERROR")
    try:
        raw.set_montage(mne.channels.make_standard_montage("standard_1020"), on_missing="ignore")
    except Exception:
        # Deliberate best-effort: keep the data usable without positions.
        pass
    return raw


def apply_bandpass_iir(raw, hpass, lpass):
    """Return a band-pass filtered copy of *raw*; the input is not modified.

    Uses a zero-phase IIR filter configured by ``IIR_BANDPASS_PARAMS``, with
    *hpass* / *lpass* as the low / high cutoffs in Hz.
    """
    filter_kwargs = {
        "l_freq": hpass,
        "h_freq": lpass,
        "method": "iir",
        "iir_params": IIR_BANDPASS_PARAMS,
        "phase": "zero",
        "verbose": "ERROR",
    }
    filtered = raw.copy()
    return filtered.filter(**filter_kwargs)


def apply_notch_iir_one_by_one(raw, notch_freqs):
    """Return a copy of *raw* with a separate IIR notch at each frequency.

    Each entry of *notch_freqs* gets its own zero-phase ``notch_filter`` call
    (configured by ``IIR_NOTCH_PARAMS``), applied in order to the same copy.
    The input object is left unchanged; an empty *notch_freqs* returns an
    unfiltered copy.
    """
    out = raw.copy()
    notch_kwargs = {
        "method": "iir",
        "iir_params": IIR_NOTCH_PARAMS,
        "phase": "zero",
        "verbose": "ERROR",
    }
    for freq in map(float, notch_freqs):
        out.notch_filter(freqs=freq, **notch_kwargs)
    return out


# -----------------------------
# ICA REVIEW DIALOG (no MNE block= argument; figures use plt.show(block=False))
# -----------------------------
class ICAReviewDialog(tk.Toplevel):
    """Modal checklist for choosing which ICA components to exclude.

    ``__init__`` does not return until the dialog is closed (it ends with
    ``wait_window``). The caller then reads:

    * ``action`` -- ``"skip"`` if "Skip participant" was pressed, otherwise
      ``"continue"`` (closing the window also counts as Continue).
    * ``exclude`` -- indices of the checked components; empty on skip.

    The Preview buttons open MNE figures and display them with
    ``plt.show(block=False)`` so the dialog stays usable while they are open.
    """

    def __init__(self, parent, ica: ICA, raw_for_ica: mne.io.BaseRaw):
        super().__init__(parent)
        self.title("ICA Review")
        self.resizable(True, True)

        self.ica = ica
        self.raw_for_ica = raw_for_ica  # data shown by preview_sources()
        self.exclude = []
        self.action = "continue"  # or "skip"

        n_components = ica.n_components_

        container = ttk.Frame(self, padding=10)
        container.pack(fill="both", expand=True)

        ttk.Label(
            container,
            text="Select ICA components to EXCLUDE. Use Preview buttons if needed, then click Continue.",
            wraplength=600,
        ).pack(anchor="w", pady=(0, 8))

        # Scrollable checklist: a Frame embedded in a Canvas whose scroll
        # region is refreshed whenever the frame is resized.
        canvas = tk.Canvas(container, height=280)
        scrollbar = ttk.Scrollbar(container, orient="vertical", command=canvas.yview)
        scroll_frame = ttk.Frame(canvas)

        scroll_frame.bind("<Configure>", lambda e: canvas.configure(scrollregion=canvas.bbox("all")))
        canvas.create_window((0, 0), window=scroll_frame, anchor="nw")
        canvas.configure(yscrollcommand=scrollbar.set)

        canvas.pack(side="left", fill="both", expand=True)
        scrollbar.pack(side="right", fill="y")

        # One checkbox per component; self.vars[i] belongs to component i.
        self.vars = []
        for i in range(n_components):
            v = tk.BooleanVar(value=False)
            ttk.Checkbutton(scroll_frame, text=f"Component {i}", variable=v).pack(anchor="w")
            self.vars.append(v)

        # Buttons
        btns = ttk.Frame(container)
        btns.pack(fill="x", pady=(10, 0))

        ttk.Button(btns, text="Select none", command=self.select_none).pack(side="left")
        ttk.Button(btns, text="Select all", command=self.select_all).pack(side="left", padx=6)

        ttk.Button(btns, text="Preview components (topos)", command=self.preview_components).pack(side="left", padx=12)
        ttk.Button(btns, text="Preview sources", command=self.preview_sources).pack(side="left")

        ttk.Button(btns, text="Skip participant", command=self.skip).pack(side="right")
        ttk.Button(btns, text="Continue", command=self.cont).pack(side="right", padx=6)

        # Closing the window keeps the current selection (same as Continue).
        self.protocol("WM_DELETE_WINDOW", self.cont)

        # Modal. The window must be mapped before grab_set(): grabbing a
        # window that is not yet viewable raises TclError ("grab failed:
        # window not viewable") on X11.
        self.transient(parent)
        self.wait_visibility()
        self.grab_set()
        self.focus_set()
        self.wait_window()

    def select_none(self):
        """Uncheck every component."""
        for v in self.vars:
            v.set(False)

    def select_all(self):
        """Check every component."""
        for v in self.vars:
            v.set(True)

    def _show_preview_error(self, what, err):
        """Report a failed preview in a message box, keeping the dialog open."""
        messagebox.showerror("Preview failed", f"Could not plot ICA {what}:\n{err!r}", parent=self)

    def preview_components(self):
        """Open the component topography figure(s) without blocking the dialog."""
        try:
            # show=False: with show=True MNE may call a plain plt.show(), which
            # can block this Tk callback until the figures are closed
            # (backend / interactive-mode dependent). Show them ourselves.
            self.ica.plot_components(show=False)
            plt.show(block=False)
        except Exception as err:
            # e.g. possibly missing channel positions if the montage could not
            # be attached when the data were loaded.
            self._show_preview_error("components", err)

    def preview_sources(self):
        """Open the ICA source time-course browser without blocking the dialog."""
        try:
            self.ica.plot_sources(self.raw_for_ica, show=True)
            plt.show(block=False)
        except Exception as err:
            self._show_preview_error("sources", err)

    def cont(self):
        """Record the checked components and close with action 'continue'."""
        self.exclude = [i for i, v in enumerate(self.vars) if v.get()]
        self.action = "continue"
        self.grab_release()
        self.destroy()

    def skip(self):
        """Close with action 'skip' and no exclusions."""
        self.exclude = []
        self.action = "skip"
        self.grab_release()
        self.destroy()


# -----------------------------
# PROCESS ONE PARTICIPANT
# -----------------------------
def process_one(parent, name, raw_path, event_path, out_dir, params):
    """Run the preprocessing pipeline for one participant.

    Steps: load raw txt and events -> band-pass -> notch(es) -> fit ICA on a
    1 Hz high-passed copy -> interactive component review -> apply ICA to the
    filtered data -> epoch around events -> save outputs to *out_dir*.

    Args:
        parent: Tk widget used as parent of the ICA review dialog.
        name: Participant name; used as the output file prefix.
        raw_path: Path of the 9-column raw data file.
        event_path: Path of the event-words file.
        out_dir: Output directory (created if needed).
        params: Dict with keys sfreq, ch_names, hpass, lpass, notch_freqs,
            tmin, tmax, baseline_on, ica_n_components, ica_method,
            ica_random_state.

    Returns:
        A summary dict. When the user skips the participant it has
        ``skipped=True`` and a ``reason``; otherwise it holds epoch and
        dropped-event counts, the excluded components and the event_id map.

    Raises:
        ValueError: If no event falls inside the recording (plus any error
            raised while reading, filtering, fitting or saving).

    When not skipped, writes ``{name}_cleaned_raw.fif``,
    ``{name}_cleaned-epo.fif``, ``{name}_cleaned_epochs.npy``,
    ``{name}_ica.fif`` and ``{name}_event_summary.csv``, overwriting any
    existing files.
    """
    sfreq = params["sfreq"]
    ch_names = params["ch_names"]
    hpass = params["hpass"]
    lpass = params["lpass"]
    notch_freqs = params["notch_freqs"]
    tmin = params["tmin"]
    tmax = params["tmax"]
    baseline_on = params["baseline_on"]

    # Load data AND events before filtering/ICA, so a bad event file is
    # reported before the user spends time on the interactive ICA review.
    data_v = read_9col_raw_txt(raw_path)
    n_samples = data_v.shape[1]
    raw = make_raw_mne(data_v, sfreq=sfreq, ch_names=ch_names)

    ev_df = read_event_words_txt(event_path)
    events, event_id, ev_df_inrange, dropped = build_events(ev_df, sfreq, n_samples)
    if len(events) == 0:
        raise ValueError(
            f"{os.path.basename(event_path)}: none of the {len(ev_df)} events fall inside the recording."
        )
    # mne.Epochs raises by default when an event_id label has no events
    # (e.g. all of its latencies were out of range); keep present codes only.
    present_codes = set(events[:, 2].tolist())
    event_id = {lab: code for lab, code in event_id.items() if code in present_codes}

    # Filter
    raw_bp = apply_bandpass_iir(raw, hpass, lpass)
    raw_filt = apply_notch_iir_one_by_one(raw_bp, notch_freqs)

    # ICA is fit on a 1 Hz high-passed copy; the resulting unmixing is later
    # applied to raw_filt, which keeps the user's own high-pass cutoff.
    raw_for_ica = raw_filt.copy().filter(
        l_freq=1.0,
        h_freq=lpass,
        method="iir",
        iir_params=IIR_BANDPASS_PARAMS,
        phase="zero",
        verbose="ERROR",
    )

    ica = ICA(
        n_components=params["ica_n_components"],
        method=params["ica_method"],
        random_state=params["ica_random_state"],
        max_iter="auto",
    )
    ica.fit(raw_for_ica, verbose="ERROR")

    # Review dialog (blocks until closed). Closing the window counts as
    # Continue, so "skip" only comes from the Skip button.
    dlg = ICAReviewDialog(parent, ica, raw_for_ica)
    if dlg.action == "skip":
        return {"name": name, "skipped": True, "reason": "ICA skipped by user"}

    ica.exclude = dlg.exclude

    # Apply ICA (in place on a copy of the filtered data)
    raw_clean = raw_filt.copy()
    ica.apply(raw_clean)

    # Epoch
    baseline = (tmin, 0.0) if baseline_on else None

    epochs = mne.Epochs(
        raw_clean,
        events=events,
        event_id=event_id,
        tmin=tmin,
        tmax=tmax,
        baseline=baseline,
        preload=True,
        reject_by_annotation=False,
        verbose="ERROR",
    )

    # Save
    os.makedirs(out_dir, exist_ok=True)
    raw_out = os.path.join(out_dir, f"{name}_cleaned_raw.fif")
    epo_out = os.path.join(out_dir, f"{name}_cleaned-epo.fif")
    npy_out = os.path.join(out_dir, f"{name}_cleaned_epochs.npy")
    ica_out = os.path.join(out_dir, f"{name}_ica.fif")
    summ_out = os.path.join(out_dir, f"{name}_event_summary.csv")

    raw_clean.save(raw_out, overwrite=True, verbose="ERROR")
    epochs.save(epo_out, overwrite=True, verbose="ERROR")
    np.save(npy_out, epochs.get_data())
    # overwrite=True like the other outputs; otherwise re-running into the
    # same folder fails here after the other files were already rewritten.
    ica.save(ica_out, overwrite=True)

    # Counts are of in-range events (before any epoch MNE may drop at edges).
    counts = ev_df_inrange["type"].value_counts().sort_index()
    pd.DataFrame({"event": counts.index, "count": counts.values}).to_csv(summ_out, index=False)

    return {
        "name": name,
        "skipped": False,
        "n_epochs": int(len(epochs)),
        "n_dropped_events": int(dropped),
        "excluded_components": ",".join(map(str, dlg.exclude)) if dlg.exclude else "",
        "event_id": str(event_id),
    }


# -----------------------------
# MAIN GUI APP
# -----------------------------
class App(tk.Tk):
    """Main window: choose folders and settings, then batch-process all pairs.

    For each raw/event pair found by ``find_pairs`` the app runs
    ``process_one`` (which opens an ICA review dialog per participant) and
    finally writes ``batch_summary.csv`` to the output folder.
    """

    def __init__(self):
        super().__init__()
        self.title("EEG Batch Preprocess (Raw txt + Event words txt)")
        self.resizable(False, False)

        self.in_dir = tk.StringVar(value="")
        self.out_dir = tk.StringVar(value="")

        self.hpass = tk.DoubleVar(value=DEFAULT_HPASS)
        self.lpass = tk.DoubleVar(value=DEFAULT_LPASS)
        self.notch_lo = tk.IntVar(value=DEFAULT_NOTCH_LO)
        self.notch_hi = tk.IntVar(value=DEFAULT_NOTCH_HI)

        self.tmin = tk.DoubleVar(value=DEFAULT_TMIN)
        self.tmax = tk.DoubleVar(value=DEFAULT_TMAX)
        self.baseline_on = tk.BooleanVar(value=DEFAULT_BASELINE_ON)

        self.status = tk.StringVar(value="Ready.")
        self._build()

    def _build(self):
        """Create and lay out all widgets."""
        pad = {"padx": 10, "pady": 6}
        frm = ttk.Frame(self, padding=10)
        frm.pack(fill="both", expand=True)

        ttk.Label(frm, text="Input folder (contains all '* raw*.txt' and '* event words*.txt')").grid(row=0, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.in_dir, width=60).grid(row=1, column=0, sticky="w", **pad)
        ttk.Button(frm, text="Browse…", command=self.pick_in_dir).grid(row=1, column=1, sticky="w")

        ttk.Label(frm, text="Output folder").grid(row=2, column=0, sticky="w")
        ttk.Entry(frm, textvariable=self.out_dir, width=60).grid(row=3, column=0, sticky="w", **pad)
        ttk.Button(frm, text="Browse…", command=self.pick_out_dir).grid(row=3, column=1, sticky="w")

        ttk.Label(frm, text="Filtering (IIR; safe for short recordings)").grid(row=4, column=0, sticky="w", pady=(12, 0))
        ffrm = ttk.Frame(frm)
        ffrm.grid(row=5, column=0, columnspan=2, sticky="w", **pad)

        ttk.Label(ffrm, text="Bandpass (Hz):").grid(row=0, column=0, sticky="w")
        ttk.Entry(ffrm, textvariable=self.hpass, width=8).grid(row=0, column=1, sticky="w", padx=(6, 12))
        ttk.Label(ffrm, text="to").grid(row=0, column=2, sticky="w")
        ttk.Entry(ffrm, textvariable=self.lpass, width=8).grid(row=0, column=3, sticky="w", padx=(6, 12))

        ttk.Label(ffrm, text="Notch (Hz):").grid(row=0, column=4, sticky="w")
        ttk.Entry(ffrm, textvariable=self.notch_lo, width=6).grid(row=0, column=5, sticky="w", padx=(6, 6))
        ttk.Label(ffrm, text="to").grid(row=0, column=6, sticky="w")
        ttk.Entry(ffrm, textvariable=self.notch_hi, width=6).grid(row=0, column=7, sticky="w")

        ttk.Label(frm, text="Epoching").grid(row=6, column=0, sticky="w", pady=(12, 0))
        efrm = ttk.Frame(frm)
        efrm.grid(row=7, column=0, columnspan=2, sticky="w", **pad)

        ttk.Label(efrm, text="tmin (s):").grid(row=0, column=0, sticky="w")
        ttk.Entry(efrm, textvariable=self.tmin, width=8).grid(row=0, column=1, sticky="w", padx=(6, 12))
        ttk.Label(efrm, text="tmax (s):").grid(row=0, column=2, sticky="w")
        ttk.Entry(efrm, textvariable=self.tmax, width=8).grid(row=0, column=3, sticky="w", padx=(6, 12))
        ttk.Checkbutton(efrm, text="Baseline correct (tmin to 0)", variable=self.baseline_on).grid(row=0, column=4, sticky="w")

        ttk.Button(frm, text="Run batch preprocessing", command=self.run).grid(row=8, column=0, sticky="w", pady=(12, 0))
        ttk.Label(frm, textvariable=self.status, wraplength=560).grid(row=9, column=0, columnspan=2, sticky="w", pady=(6, 0))

    def pick_in_dir(self):
        """Ask for the input folder; default the output folder inside it."""
        d = filedialog.askdirectory(title="Select folder with raw + event words files")
        if d:
            self.in_dir.set(d)
            if not self.out_dir.get().strip():
                self.out_dir.set(os.path.join(d, "preprocessed_out"))

    def pick_out_dir(self):
        """Ask for the output folder."""
        d = filedialog.askdirectory(title="Select output folder")
        if d:
            self.out_dir.set(d)

    def _collect_params(self):
        """Read and validate the settings fields.

        Returns:
            The ``params`` dict for ``process_one``, or ``None`` after showing
            an error box when a field is not numeric or values conflict.
        """
        try:
            hpass = float(self.hpass.get())
            lpass = float(self.lpass.get())
            notch_lo = int(self.notch_lo.get())
            notch_hi = int(self.notch_hi.get())
            tmin = float(self.tmin.get())
            tmax = float(self.tmax.get())
        except (tk.TclError, ValueError):
            # Tk variables raise TclError when the Entry text is not a number.
            messagebox.showerror("Error", "Filter and epoch settings must be numeric.")
            return None

        if notch_hi < notch_lo:
            messagebox.showerror("Error", "Notch high must be >= notch low.")
            return None
        if hpass >= lpass:
            messagebox.showerror("Error", "Bandpass low cutoff must be below the high cutoff.")
            return None
        if tmin >= tmax:
            messagebox.showerror("Error", "tmin must be smaller than tmax.")
            return None

        return {
            "sfreq": DEFAULT_SFREQ,
            "ch_names": DEFAULT_CH_NAMES,
            "hpass": hpass,
            "lpass": lpass,
            # Every integer frequency from notch_lo to notch_hi, inclusive.
            "notch_freqs": np.arange(notch_lo, notch_hi + 1, 1),
            "tmin": tmin,
            "tmax": tmax,
            "baseline_on": bool(self.baseline_on.get()),
            "ica_n_components": DEFAULT_ICA_N_COMPONENTS,
            "ica_method": DEFAULT_ICA_METHOD,
            "ica_random_state": DEFAULT_ICA_RANDOM_STATE,
        }

    def run(self):
        """Validate inputs, process every participant, then write the summary.

        A failure for one participant is recorded in the summary and the
        batch continues with the next one.
        """
        in_dir = self.in_dir.get().strip()
        out_dir = self.out_dir.get().strip()

        if not in_dir or not os.path.isdir(in_dir):
            messagebox.showerror("Error", "Please select a valid input folder.")
            return
        if not out_dir:
            messagebox.showerror("Error", "Please select a valid output folder.")
            return

        pairs = find_pairs(in_dir)
        if not pairs:
            messagebox.showerror(
                "Error",
                "No file pairs found.\nExpected filenames like:\n  'Alec raw.txt'\n  'Alec event words.txt'"
            )
            return

        params = self._collect_params()
        if params is None:
            return

        summary = []
        errors = 0
        for idx, (name, raw_path, event_path) in enumerate(pairs, start=1):
            self.status.set(f"Processing {idx}/{len(pairs)}: {name}")
            self.update_idletasks()

            try:
                res = process_one(self, name, raw_path, event_path, out_dir, params)
            except Exception as e:
                # Do NOT stop the batch; record the failure and keep going.
                errors += 1
                summary.append({"name": name, "status": "error", "skipped": False, "error": repr(e)})
                continue

            # Normalize result rows so the CSV is easy to audit
            row = dict(res)
            if row.get("skipped", False):
                row.setdefault("status", "skipped")
                row.setdefault("error", row.get("reason", ""))
            else:
                row.setdefault("status", "ok")
                row.setdefault("error", "")
            summary.append(row)

        batch_csv = os.path.join(out_dir, "batch_summary.csv")
        try:
            os.makedirs(out_dir, exist_ok=True)
            pd.DataFrame(summary).to_csv(batch_csv, index=False)
        except OSError as e:
            # e.g. the CSV is open in another program. Per-participant outputs
            # are already saved, so report the problem instead of crashing.
            self.status.set(f"Done, but the batch summary could not be saved:\n{e}")
            messagebox.showerror("Error", f"Could not write batch summary:\n{batch_csv}\n\n{e!r}")
            return

        self.status.set(f"Done. Batch summary saved:\n{batch_csv}")
        if errors:
            messagebox.showwarning("Finished (with issues)",
                                   f"Batch processing complete, but {errors} file(s) had errors.\n\n"
                                   f"See the batch summary for details:\n{batch_csv}")
        else:
            messagebox.showinfo("Finished", f"Batch processing complete.\n\nBatch summary:\n{batch_csv}")


if __name__ == "__main__":
    # Best-effort: on Windows, declare the process DPI-aware (value 1) so Tk
    # widgets are not bitmap-scaled on high-DPI displays. ctypes.windll does
    # not exist on other platforms; any failure here is ignored on purpose.
    try:
        import ctypes

        shcore = ctypes.windll.shcore
        shcore.SetProcessDpiAwareness(1)
    except Exception:
        pass

    # Build the main window and hand control to the Tk event loop.
    app = App()
    app.mainloop()
